Batchnormgrad
逐元素计算加法梯度
计算批标准化 (Batch Normalization) 的梯度。
该算子计算损失函数 L 分别对输入 x、缩放因子 scale (γ) 和偏置 bias (β) 的梯度。其中 bias 的梯度为 dbias。
\[\begin{split}dscale(\gamma) &= \sum_{i=1}^{m} dy_i \cdot \hat{x}_i \\
dbias(\beta) &= \sum_{i=1}^{m} dy_i\end{split}\]
\[dx_i = \frac{\gamma}{m\sqrt{\sigma^2 + \epsilon}} \left[ m \cdot dy_i - \sum_{j=1}^{m}dy_j - \hat{x}_i \sum_{j=1}^{m}dy_j \hat{x}_j \right]\]
其中 \(m\) 是批处理大小 (batch),\(\hat{x}\) 是归一化后的 \(x\)。
- 输入:
x - 前向传播时的输入张量。
dy - 来自后一层的上游梯度。
mean - 前向传播时计算的均值。
invar - 前向传播时计算的逆方差 (1 / sqrt(variance + epsilon))。
scale - 前向传播时使用的缩放因子 (gamma, γ)。
batch - 批处理大小。
channel - 通道数。
is_train - 是否为训练模式。梯度计算通常在训练时进行。
core_mask - 核掩码。
- 输出:
dx - 对输入 x 的梯度。
dbias - 对偏置 bias (β) 的梯度。
dscale - 对缩放因子 scale (γ) 的梯度。
- 支持平台:
FT78NEMT7004
备注
FT78NE 支持fp32
MT7004 支持fp16, fp32
共享存储版本:
-
void fp_batchnormgrad_s(float *x, float *dy, float *mean, float *invar, float *scale, int batch, int channel, int is_train, float *dx, float *dbias, float *dscale, int core_mask)
-
void hp_batchnormgrad_s(half *x, half *dy, half *mean, half *invar, half *scale, int batch, int channel, int is_train, half *dx, half *dbias, half *dscale, int core_mask)
C调用示例:
1//FT78NE示例
2#include <stdio.h>
3#include <batchnormgrad.h>
4int main(int argc, char* argv[]) {
5 float *x = (float *)0xA0000000; // forward input x
6 float *dy = (float *)0xB0000000; // upstream gradient dy
7 float *mean = (float *)0xC0000000; // forward mean
8 float *invar = (float *)0xD0000000; // forward inverse variance
9 float *scale = (float *)0xE0000000; // forward scale (gamma)
10
11 float *dx = (float *)0xA1000000; // output gradient dx
12 float *dbias = (float *)0xB1000000; // output gradient dbias
13 float *dscale = (float *)0xC1000000; // output gradient dscale
14
15 int batch = 4;
16 int channel = 64;
17 int is_train = true;
18 int core_mask = 0xff;
19
20 fp_batchnormgrad_s(x, dy, mean, invar, scale, batch, channel, is_train, dx, dbias, dscale, core_mask);
21 return 0;
22}
私有存储版本:
-
void fp_batchnormgrad_p(float *x, float *dy, float *mean, float *invar, float *scale, int batch, int channel, int is_train, float *dx, float *dbias, float *dscale)
-
void hp_batchnormgrad_p(half *x, half *dy, half *mean, half *invar, half *scale, int batch, int channel, int is_train, half *dx, half *dbias, half *dscale)
C调用示例:
1//FT78NE示例
2#include <stdio.h>
3#include <batchnormgrad.h>
4int main(int argc, char* argv[]) {
5 float *x = (float *)0x10000000; // forward input x in L2 space
6 float *dy = (float *)0x10100000; // upstream gradient dy
7 float *mean = (float *)0x10200000; // forward mean
8 float *invar = (float *)0x10300000; // forward inverse variance
9 float *scale = (float *)0x10400000; // forward scale (gamma)
10
11 float *dx = (float *)0x10500000; // output gradient dx
12 float *dbias = (float *)0x10600000; // output gradient dbias
13 float *dscale = (float *)0x10700000; // output gradient dscale
14
15 int batch = 4;
16 int channel = 32;
17 int is_train = true;
18
19 fp_batchnormgrad_p(x, dy, mean, invar, scale, batch, channel, is_train, dx, dbias, dscale);
20 return 0;
21}